{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
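    "# Customize Optimization\n",
    "\n",
    "A main design goal of Apache TVM is to make the optimization pipeline easy to customize, whether for research, development, or iterative engineering optimization."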
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import IRModule, relax\n",
    "from tvm.relax.frontend import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
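    "## Composable IRModule Optimization\n",
    "\n",
    "Apache TVM provides a flexible way to optimize an IRModule. Every transformation centered on IRModule optimization can be composed with existing pipelines. Note that each optimization can focus on **part of the computation graph**, enabling partial lowering or partial optimization.\n",
    "\n",
    "## Prepare a Relax Module\n",
    "\n",
    "We first prepare a Relax module. The module can be imported from other frameworks, or constructed with the NN module frontend or TVMScript. Here we use a simple neural network model as an example."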
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">forward</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            permute_dims: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">784</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc1_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
       "            matmul: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(x, permute_dims, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            add: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul, fc1_bias)\n",
       "            relu: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>nn<span style=\"color: #A2F; font-weight: bold\">.</span>relu(add)\n",
       "            permute_dims1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc2_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
       "            matmul1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(relu, permute_dims1, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> matmul1\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "class RelaxModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(RelaxModel, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 256)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(256, 10, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu1(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "input_shape = (1, 784)\n",
    "mod, params = RelaxModel().export_tvm({\"forward\": {\"x\": nn.spec.Tensor(input_shape, \"float32\")}})\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
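    "## Library Dispatch\n",
    "\n",
    "We would like to quickly try out library-backed optimizations for a specific platform (e.g., GPU). We can write a dedicated dispatch pass for a given platform and operator. Here we demonstrate how to dispatch to the CUBLAS library for certain patterns.\n",
    "\n",
    "```{note}\n",
    "This tutorial only demonstrates dispatching a single operator pattern to CUBLAS, to highlight the flexibility of the optimization pipeline. In real-world cases, we can import multiple patterns and dispatch them to different kernels.\n",
    "```"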
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "# Import cublas pattern\n",
    "import tvm.relax.backend.contrib.cublas as _cublas\n",
    "\n",
    "\n",
    "# Define a new pass for CUBLAS dispatch\n",
    "@tvm.transform.module_pass(opt_level=0, name=\"CublasDispatch\")\n",
    "class CublasDispatch:\n",
    "    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n",
    "        # Check if CUBLAS is enabled (the lookup returns None when the function is missing)\n",
    "        if not tvm.get_global_func(\"relax.ext.cublas\", allow_missing=True):\n",
    "            raise RuntimeError(\"CUBLAS is not enabled.\")\n",
    "\n",
    "        # Get interested patterns\n",
    "        patterns = [relax.backend.get_pattern(\"cublas.matmul_transposed_bias_relu\")]\n",
    "        # Note in real-world cases, we usually get all patterns\n",
    "        # patterns = relax.backend.get_patterns_with_prefix(\"cublas\")\n",
    "\n",
    "        # Fuse ops by patterns and then run codegen\n",
    "        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)\n",
    "        mod = relax.transform.RunCodegen()(mod)\n",
    "        return mod\n",
    "\n",
    "\n",
    "mod = CublasDispatch()(mod)\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
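    "## After the Dispatch Pass\n",
    "\n",
    "We can see that the first ``nn.Linear`` and ``nn.ReLU`` are fused and rewritten into a ``call_dps_packed`` function that calls the CUBLAS library. Notably, the other parts are unchanged, which means we can selectively dispatch optimizations for specific parts of the computation.\n",
    "\n",
    "## Auto Tuning\n",
    "\n",
    "Continuing from the previous example, we can further optimize the **rest of the computation** with auto-tuning. Here we demonstrate how to use MetaSchedule to auto-tune the model.\n",
    "\n",
    "We can use the ``MetaScheduleTuneTIR`` pass to tune the model, and the ``MetaScheduleApplyDatabase`` pass to apply the best configuration found. The tuning pass generates the search space and tunes the model; the subsequent pass applies the best configuration to the model. Before running these passes, we need to lower Relax operators into TensorIR functions via ``LegalizeOps``."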
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "device = tvm.cuda(0)\n",
    "target = tvm.target.Target.from_device(device)\n",
    "if os.getenv(\"CI\", \"\") != \"true\":\n",
    "    trials = 2000\n",
    "    with target, tempfile.TemporaryDirectory() as tmp_dir:\n",
    "        mod = tvm.ir.transform.Sequential(\n",
    "            [\n",
    "                relax.get_pipeline(\"zero\"),\n",
    "                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),\n",
    "                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),\n",
    "            ]\n",
    "        )(mod)\n",
    "\n",
    "    mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
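    "## DLight Rules\n",
    "\n",
    "DLight rules are a set of default rules for scheduling and optimizing kernels. They are designed for fast compilation and **fair** performance. In some cases, such as language models, DLight provides excellent performance, while for generic models it strikes a balance between performance and compilation time."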
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "from tvm import dlight as dl\n",
    "\n",
    "# Apply DLight rules\n",
    "with target:\n",
    "    mod = tvm.ir.transform.Sequential(\n",
    "        [\n",
    "            relax.get_pipeline(\"zero\"),\n",
    "            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable\n",
    "                dl.gpu.Matmul(),\n",
    "                dl.gpu.GEMV(),\n",
    "                dl.gpu.Reduction(),\n",
    "                dl.gpu.GeneralReduction(),\n",
    "                dl.gpu.Fallback(),\n",
    "            ),\n",
    "        ]\n",
    "    )(mod)\n",
    "\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
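    "This tutorial focuses on demonstrating the optimization pipeline rather than pushing performance to the limit. The current optimizations may not be the best.\n",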
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy the Optimized Model\n",
    "\n",
    "We can build the optimized model and deploy it to the TVM runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ex = tvm.compile(mod, target=\"cuda\")\n",
    "dev = tvm.device(\"cuda\", 0)\n",
    "vm = relax.VirtualMachine(ex, dev)\n",
    "# Need to allocate data and params on GPU device\n",
    "data = tvm.runtime.tensor(np.random.rand(*input_shape).astype(\"float32\"), dev)\n",
    "gpu_params = [tvm.runtime.tensor(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]\n",
    "gpu_out = vm[\"forward\"](data, *gpu_params).numpy()\n",
    "print(gpu_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
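    "## Summary\n",
    "\n",
    "This tutorial demonstrates how to customize the optimization pipeline for machine learning models in Apache TVM. We can easily compose optimization passes and customize the optimization for different parts of the computation graph. The flexibility of the optimization pipeline enables us to quickly iterate on optimizations and improve model performance."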
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py313",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
